﻿/*
 * *******************************************************
 *
 * 作者：HZY
 *
 * 开源地址：https://gitee.com/hzy6
 *
 * *******************************************************
 */

namespace HZY.Framework.Core.Utils;

/// <summary>
/// 程序集工具
/// </summary>
/// <summary>
/// Assembly utilities: discovers the application's own assemblies (skipping framework and
/// well-known third-party libraries) and caches the scan result.
/// </summary>
public static class AssemblyUtil
{
    /// <summary>
    /// Keywords whose presence in an assembly name (or DLL file name) excludes it from scanning.
    /// Matching is an ordinal, case-sensitive substring check.
    /// </summary>
    /// <remarks>
    /// NOTE(review): because this is a substring match, an application assembly whose name merely
    /// contains one of these words (e.g. "MyApp.SystemManage") is skipped too — confirm this is intended.
    /// </remarks>
    private static readonly string[] IgnoreKeywords =
    {
        nameof(Microsoft),
        nameof(System),
        "Swashbuckle",
        "Npgsql",
        "NPOI",
        "FreeSql",
        "Castle",
        "Azure",
        "netstandard"
    };

    /// <summary>
    /// Serializes the one-time scan so concurrent first callers cannot populate the cache twice.
    /// </summary>
    private static readonly object AssemblyListLock = new();

    /// <summary>
    /// Cache of the complete, unfiltered scan; <c>null</c> until the first successful scan.
    /// </summary>
    private static IReadOnlyList<Assembly>? _assemblyList;

    /// <summary>
    /// Scans for application assemblies in three passes:
    /// (1) assemblies already loaded into <see cref="AssemblyLoadContext.Default"/>;
    /// (2) the entry assembly and the assemblies it references directly;
    /// (3) every <c>*.dll</c> in the application base directory that was not picked up yet.
    /// Assemblies whose names contain an <see cref="IgnoreKeywords"/> entry are skipped.
    /// </summary>
    /// <param name="assemblyNames">
    /// Optional keywords (fuzzy substring match). When at least one non-blank keyword is given, only
    /// assemblies whose simple name (or, in pass 3, file name) contains one of them are kept.
    /// </param>
    /// <returns>The discovered assemblies, one per simple name, in discovery order.</returns>
    private static IEnumerable<Assembly> ScanAllAssembly(params string[]? assemblyNames)
    {
        var keywords = NormalizeKeywords(assemblyNames);
        var result = new List<Assembly>();

        // A load context identifies assemblies by simple name (case-insensitively), so de-duplicate on
        // the exact name; a substring check would wrongly drop e.g. "A.B.Extensions" once "A.B" is seen.
        var seenNames = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

        void AddIfNew(Assembly assembly)
        {
            if (seenNames.Add(assembly.GetName().Name ?? string.Empty))
            {
                result.Add(assembly);
            }
        }

        bool IsWanted(string? name) => !IsIgnored(name) && MatchesKeywords(name, keywords);

        // Pass 1: assemblies already loaded into the default context.
        foreach (var loadedAssembly in AssemblyLoadContext.Default.Assemblies)
        {
            if (IsWanted(loadedAssembly.GetName().Name))
            {
                AddIfNew(loadedAssembly);
            }
        }

        // Pass 2: the entry assembly and its direct references. GetEntryAssembly() may return null
        // (e.g. when hosted from unmanaged code); pass 3 does not depend on it, so only this pass is skipped.
        var entryAssembly = Assembly.GetEntryAssembly();
        if (entryAssembly != null)
        {
            if (IsWanted(entryAssembly.GetName().Name))
            {
                AddIfNew(entryAssembly);
            }

            foreach (var reference in entryAssembly.GetReferencedAssemblies())
            {
                if (!IsWanted(reference.Name) || seenNames.Contains(reference.Name ?? string.Empty))
                {
                    continue;
                }

                var referenced = TryLoad(() => AssemblyLoadContext.Default.LoadFromAssemblyName(reference), reference.Name);
                if (referenced != null)
                {
                    AddIfNew(referenced);
                }
            }
        }

        // Pass 3: load every remaining DLL in the base directory so nothing unreferenced is missed.
        // Dynamic assemblies are excluded because reading their Location can throw NotSupportedException.
        var loadedPaths = new HashSet<string>(
            result.Where(a => !a.IsDynamic).Select(a => a.Location),
            StringComparer.Ordinal);

        foreach (var path in Directory.GetFiles(AppDomain.CurrentDomain.BaseDirectory))
        {
            var fileName = Path.GetFileName(path);
            if (!fileName.EndsWith(".dll", StringComparison.OrdinalIgnoreCase)
                || !IsWanted(fileName)
                || loadedPaths.Contains(path)
                || !File.Exists(path))
            {
                continue;
            }

            var fromFile = TryLoad(() => AssemblyLoadContext.Default.LoadFromAssemblyPath(path), path);
            if (fromFile != null)
            {
                AddIfNew(fromFile);
            }
        }

        return result;
    }

    /// <summary>
    /// Gets the application's assemblies. The full scan runs once and is cached; keywords only filter
    /// the cached result, so calls with different keywords do not affect each other.
    /// </summary>
    /// <param name="assemblyNames">
    /// Optional keywords (fuzzy substring match against each assembly's simple name). Null or blank
    /// keywords are ignored; when no keyword remains, every scanned assembly is returned.
    /// </param>
    /// <returns>The matching assemblies.</returns>
    public static IEnumerable<Assembly> GetAssemblyList(params string[]? assemblyNames)
    {
        var assemblies = GetCachedAssemblies();
        var keywords = NormalizeKeywords(assemblyNames);
        if (keywords.Length == 0)
        {
            return assemblies;
        }

        return assemblies
            .Where(a => MatchesKeywords(a.GetName().Name, keywords))
            .ToList();
    }

    /// <summary>
    /// Returns the cached full scan, running it on first use. Only a successful scan is cached;
    /// if <see cref="ScanAllAssembly"/> throws, the next call retries.
    /// </summary>
    private static IReadOnlyList<Assembly> GetCachedAssemblies()
    {
        lock (AssemblyListLock)
        {
            return _assemblyList ??= ScanAllAssembly().ToList().AsReadOnly();
        }
    }

    /// <summary>
    /// Drops null/blank keywords; a null array becomes an empty one (meaning "no filter").
    /// </summary>
    private static string[] NormalizeKeywords(string[]? assemblyNames) =>
        assemblyNames is null
            ? Array.Empty<string>()
            : assemblyNames.Where(s => !string.IsNullOrWhiteSpace(s)).ToArray();

    /// <summary>
    /// True when no keywords are given, or when <paramref name="name"/> contains at least one of them (ordinal).
    /// </summary>
    private static bool MatchesKeywords(string? name, string[] keywords) =>
        keywords.Length == 0 || keywords.Any(k => (name ?? string.Empty).Contains(k));

    /// <summary>
    /// True when <paramref name="name"/> contains any <see cref="IgnoreKeywords"/> entry (ordinal).
    /// </summary>
    private static bool IsIgnored(string? name) =>
        IgnoreKeywords.Any(k => (name ?? string.Empty).Contains(k));

    /// <summary>
    /// Runs <paramref name="load"/> and returns its result, or <c>null</c> if it throws. Scanning is
    /// best-effort: native DLLs, missing dependencies or name clashes must not abort the whole scan.
    /// </summary>
    /// <param name="load">The load operation to attempt.</param>
    /// <param name="description">Assembly name or path, used only for the diagnostic message.</param>
    private static Assembly? TryLoad(Func<Assembly> load, string? description)
    {
        try
        {
            return load();
        }
        catch (Exception ex)
        {
            System.Diagnostics.Debug.WriteLine($"[AssemblyUtil] Skipped '{description}': {ex.Message}");
            return null;
        }
    }
}
